/*
 * gzip.c - hooks for compression of packets
 *
 * This file is part of the SSH Library
 *
 * Copyright (c) 2003      by Aris Adamantiadis
 * Copyright (c) 2009      by Andreas Schneider <asn@cryptomilk.org>
 *
 * The SSH Library is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 *
 * The SSH Library is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
 * License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with the SSH Library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
 * MA 02111-1307, USA.
 */

#include "config.h"

#include <stdlib.h>
#include <string.h>
#include <zlib.h>

#include "libssh/buffer.h"
#include "libssh/crypto.h"
#include "libssh/priv.h"
#include "libssh/session.h"

#ifndef BLOCKSIZE
#define BLOCKSIZE 4092
#endif

/*
 * Allocate a zeroed z_stream and initialise it for deflate.
 *
 * session: session on which an SSH_FATAL error is recorded when
 *          deflateInit() does not return Z_OK.
 * level:   compression level passed straight to deflateInit().
 *
 * Returns the new stream, or NULL if the allocation or deflateInit()
 * failed. No session error is recorded for the allocation failure.
 */
static z_stream *
initcompress(ssh_session session, int level)
{
    z_stream *zs = calloc(1, sizeof(*zs));
    int rc;

    if (zs == NULL) {
        return NULL;
    }

    rc = deflateInit(zs, level);
    if (rc == Z_OK) {
        return zs;
    }

    /* zlib refused the stream: release it before reporting the status. */
    SAFE_FREE(zs);
    ssh_set_error(session,
                  SSH_FATAL,
                  "status %d initialising zlib deflate",
                  rc);
    return NULL;
}

/*
 * Deflate the contents of `source` into a newly allocated buffer.
 *
 * The deflate stream is kept in the outgoing crypto context
 * (compress_out_ctx); it is created with `level` only when that slot is
 * still NULL, otherwise the existing stream is reused and `level` is not
 * consulted.
 *
 * Returns a new ssh_buffer holding the compressed bytes (the caller
 * frees it), or NULL when no outgoing crypto context is available, the
 * stream cannot be created, an allocation fails, or deflate() returns
 * anything other than Z_OK (an SSH_FATAL error is recorded in that last
 * case).
 */
static ssh_buffer
gzip_compress(ssh_session session, ssh_buffer source, int level)
{
    unsigned char chunk[BLOCKSIZE] = {0};
    void *input = ssh_buffer_get(source);
    uint32_t input_len = ssh_buffer_get_len(source);
    struct ssh_crypto_struct *crypto = NULL;
    z_stream *strm = NULL;
    ssh_buffer out = NULL;
    uint32_t produced;
    int rc;

    crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_OUT);
    if (crypto == NULL) {
        return NULL;
    }

    /* Create the deflate stream lazily, the first time it is needed. */
    if (crypto->compress_out_ctx == NULL) {
        crypto->compress_out_ctx = initcompress(session, level);
        if (crypto->compress_out_ctx == NULL) {
            return NULL;
        }
    }
    strm = crypto->compress_out_ctx;

    out = ssh_buffer_new();
    if (out == NULL) {
        return NULL;
    }

    strm->next_in = input;
    strm->avail_in = input_len;
    strm->next_out = chunk;

    for (;;) {
        strm->avail_out = BLOCKSIZE;

        rc = deflate(strm, Z_PARTIAL_FLUSH);
        if (rc != Z_OK) {
            SSH_BUFFER_FREE(out);
            ssh_set_error(session,
                          SSH_FATAL,
                          "status %d deflating zlib packet",
                          rc);
            return NULL;
        }

        /* Append whatever deflate() wrote into the scratch block. */
        produced = BLOCKSIZE - strm->avail_out;
        rc = ssh_buffer_add_data(out, chunk, produced);
        if (rc < 0) {
            SSH_BUFFER_FREE(out);
            return NULL;
        }

        /* Rewind the output cursor to the start of the scratch block. */
        strm->next_out = chunk;

        /* A completely filled block means more output may be pending. */
        if (strm->avail_out != 0) {
            break;
        }
    }

    return out;
}

/*
 * Replace the contents of `buf` with their compressed form.
 *
 * The data is compressed into a temporary buffer using the session's
 * configured compression level (session->opts.compressionlevel); `buf`
 * is then reinitialised and refilled from that temporary buffer.
 *
 * Returns 0 on success, -1 if compression, the reinit or the copy fails.
 */
int
compress_buffer(ssh_session session, ssh_buffer buf)
{
    ssh_buffer packed = NULL;
    int result = -1;

    packed = gzip_compress(session, buf, session->opts.compressionlevel);
    if (packed == NULL) {
        return -1;
    }

    if (ssh_buffer_reinit(buf) < 0) {
        goto out;
    }

    if (ssh_buffer_add_data(buf,
                            ssh_buffer_get(packed),
                            ssh_buffer_get_len(packed)) < 0) {
        goto out;
    }

    result = 0;

out:
    /* The temporary buffer is released on every path past this point. */
    SSH_BUFFER_FREE(packed);
    return result;
}

/* decompression */

/*
 * Allocate a zeroed z_stream and initialise it for inflate.
 *
 * session: session on which an SSH_FATAL error is recorded when
 *          inflateInit() does not return Z_OK.
 *
 * Returns the new stream, or NULL if the allocation or inflateInit()
 * failed. No session error is recorded for the allocation failure.
 */
static z_stream *
initdecompress(ssh_session session)
{
    z_stream *zs = calloc(1, sizeof(*zs));
    int rc;

    if (zs == NULL) {
        return NULL;
    }

    rc = inflateInit(zs);
    if (rc == Z_OK) {
        return zs;
    }

    /* zlib refused the stream: release it before reporting the status. */
    SAFE_FREE(zs);
    ssh_set_error(session,
                  SSH_FATAL,
                  "Status = %d initiating inflate context!",
                  rc);
    return NULL;
}

/*
 * Inflate the contents of `source` into a newly allocated buffer.
 *
 * The inflate stream is kept in the incoming crypto context
 * (compress_in_ctx) and is created only when that slot is still NULL.
 *
 * session: session used to look up the crypto context and to record
 *          errors.
 * source:  buffer holding the compressed input.
 * maxlen:  upper bound on the decompressed size; exceeding it aborts
 *          the operation.
 *
 * Returns a new ssh_buffer with the decompressed bytes (the caller frees
 * it), or NULL when no incoming crypto context is available, the stream
 * cannot be created, an allocation fails, inflate() reports a status
 * other than Z_OK or Z_BUF_ERROR, or the output grows beyond `maxlen`.
 * An SSH_FATAL error is recorded for the inflate failure and for the
 * size-limit violation.
 */
static ssh_buffer
gzip_decompress(ssh_session session, ssh_buffer source, size_t maxlen)
{
    struct ssh_crypto_struct *crypto = NULL;
    z_stream *zin = NULL;
    void *in_ptr = ssh_buffer_get(source);
    uint32_t in_size = ssh_buffer_get_len(source);
    unsigned char out_buf[BLOCKSIZE] = {0};
    ssh_buffer dest = NULL;
    uint32_t len;
    int status;

    crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_IN);
    if (crypto == NULL) {
        return NULL;
    }

    /* Create the inflate stream lazily, the first time it is needed. */
    zin = crypto->compress_in_ctx;
    if (zin == NULL) {
        zin = crypto->compress_in_ctx = initdecompress(session);
        if (zin == NULL) {
            return NULL;
        }
    }

    dest = ssh_buffer_new();
    if (dest == NULL) {
        return NULL;
    }

    zin->next_out = out_buf;
    zin->next_in = in_ptr;
    zin->avail_in = in_size;

    do {
        zin->avail_out = BLOCKSIZE;
        status = inflate(zin, Z_PARTIAL_FLUSH);
        /* Z_BUF_ERROR is tolerated here; any other non-Z_OK status is fatal. */
        if (status != Z_OK && status != Z_BUF_ERROR) {
            ssh_set_error(session,
                          SSH_FATAL,
                          "status %d inflating zlib packet",
                          status);
            SSH_BUFFER_FREE(dest);
            return NULL;
        }

        /* Append whatever inflate() wrote into the scratch block. */
        len = BLOCKSIZE - zin->avail_out;
        if (ssh_buffer_add_data(dest, out_buf, len) < 0) {
            SSH_BUFFER_FREE(dest);
            return NULL;
        }
        if (ssh_buffer_get_len(dest) > maxlen) {
            /*
             * Size of packet exceeded, avoid a denial of service attack.
             * Record the reason so the failure is not silent.
             */
            ssh_set_error(session,
                          SSH_FATAL,
                          "Decompressed packet exceeds the maximum allowed size");
            SSH_BUFFER_FREE(dest);
            return NULL;
        }
        /* Rewind the output cursor to the start of the scratch block. */
        zin->next_out = out_buf;
    } while (zin->avail_out == 0); /* full block: more output may follow */

    return dest;
}

/*
 * Replace the contents of `buf` with their decompressed form.
 *
 * The data is inflated into a temporary buffer, limited to `maxlen`
 * bytes of output; `buf` is then reinitialised and refilled from that
 * temporary buffer.
 *
 * Returns 0 on success, -1 if decompression, the reinit or the copy
 * fails.
 */
int
decompress_buffer(ssh_session session, ssh_buffer buf, size_t maxlen)
{
    ssh_buffer unpacked = NULL;
    int result = -1;

    unpacked = gzip_decompress(session, buf, maxlen);
    if (unpacked == NULL) {
        return -1;
    }

    if (ssh_buffer_reinit(buf) < 0) {
        goto out;
    }

    if (ssh_buffer_add_data(buf,
                            ssh_buffer_get(unpacked),
                            ssh_buffer_get_len(unpacked)) < 0) {
        goto out;
    }

    result = 0;

out:
    /* The temporary buffer is released on every path past this point. */
    SSH_BUFFER_FREE(unpacked);
    return result;
}
